import numpy as np
import sklearn
from sklearn.utils.validation import check_X_y
from sklearn.model_selection import KFold
from sklearn.kernel_approximation import Nystroem, RBFSampler
from model.rkhs.hyper_parameter import _BaseRKHSIV, _check_auto
from sklearn.preprocessing import RobustScaler as Scaler



class ApproxRKHSIV(_BaseRKHSIV):
    """Instrumental-variable style estimator on approximate RBF kernel features.

    Both the regressors and the conditioning variables are robust-scaled and
    mapped through a low-rank RBF feature map (Nystroem or RBFSampler). A
    linear coefficient vector over the regressor features is then obtained by
    solving a regularized linear system (see `fit`).

    Attributes set by `fit`:
        gamma_gm : gamma actually used for the conditioning-variable kernel
        featCond : fitted feature map for the conditioning variables
        transX : fitted scaler for the regressors
        featX : fitted feature map for the regressors
        a : coefficient vector over the regressor features
        fitted_delta : the delta value used in the fit
    """

    def __init__(self, kernel_approx='nystrom', n_components=25,
                 gamma_gm='auto', gamma_hq=0.1,
                 delta_scale='auto', delta_exp='auto', alpha_scale='auto'):
        """
        Parameters:
            kernel_approx : what approximator to use; either 'nystrom' or 'rbfsampler' (for kitchen sinks)
            n_components : how many approximation components to use
            gamma_hq : the gamma parameter for the kernel of h
            gamma_gm : the gamma parameter for the kernel of f
            delta_scale : the scale of the critical radius; delta_n = delta_scale / n**(delta_exp)
            delta_exp : the exponent of the critical radius; delta_n = delta_scale / n**(delta_exp)
            alpha_scale : the scale of the regularization; alpha = alpha_scale * (delta**4)
        """
        self.kernel_approx = kernel_approx
        self.n_components = n_components
        self.gamma_gm = gamma_gm
        self.gamma_hq = gamma_hq
        self.delta_scale = delta_scale  # worst-case critical value of RKHS spaces
        self.delta_exp = delta_exp
        self.alpha_scale = alpha_scale  # regularization strength from Theorem 5

    def _get_new_approx_instance(self, gamma):
        """Return a fresh, unfitted RBF feature map with the given gamma.

        The approximator is selected by `self.kernel_approx` and always uses
        `self.n_components` components with a fixed `random_state=1`.

        Raises:
            AttributeError : if `self.kernel_approx` is neither 'rbfsampler'
                nor 'nystrom'.
        """
        if self.kernel_approx == 'rbfsampler':
            return RBFSampler(gamma=gamma, n_components=self.n_components, random_state=1)
        elif self.kernel_approx == 'nystrom':
            return Nystroem(kernel='rbf', gamma=gamma, random_state=1, n_components=self.n_components)
        else:
            raise AttributeError("Invalid kernel approximator")

    def fit(self, AWX, model_target, AZX, type):
        """Fit the coefficient vector `self.a`.

        Parameters:
            AWX : feature matrix; used as the regressors when
                type == 'estimate_h', otherwise as the conditioning variables
            model_target : target values
            AZX : feature matrix; used as the conditioning variables when
                type == 'estimate_h', otherwise as the regressors
            type : 'estimate_h' selects AWX as regressors; any other value
                selects AZX as regressors

        Returns:
            self
        """
        # `type` shadows the builtin, but it is part of the public signature.
        if type == 'estimate_h':
            X, condition = AWX, AZX
        else:
            X, condition = AZX, AWX
        y = model_target

        X, y = check_X_y(X, y, accept_sparse=True)
        condition, y = check_X_y(condition, y, accept_sparse=True)

        # Standardize condition and get gamma_gm -> RootKf
        condition = Scaler().fit_transform(condition)
        gamma_gm = self._get_gamma_gm(condition=condition)
        # NOTE(review): this overwrites the constructor setting (e.g. 'auto')
        # with the resolved value; confirm against `_get_gamma_gm` whether a
        # later re-fit on new data is expected to reuse it.
        self.gamma_gm = gamma_gm
        self.featCond = self._get_new_approx_instance(gamma=self.gamma_gm)
        RootKf = self.featCond.fit_transform(condition)

        # Standardize X and get gamma_hq -> RootKh
        self.transX = Scaler()
        self.transX.fit(X)
        X = self.transX.transform(X)
        self.featX = self._get_new_approx_instance(gamma=self.gamma_hq)
        RootKh = self.featX.fit_transform(X)

        # delta & alpha
        n = X.shape[0]
        delta = self._get_delta(n)
        alpha = self._get_alpha(delta, self._get_alpha_scale())

        # Size the identity matrices from the feature matrices themselves:
        # Nystroem reduces the number of components to n_samples when
        # n_components > n_samples, so self.n_components may not match.
        n_feat_f = RootKf.shape[1]
        n_feat_h = RootKh.shape[1]

        Q = np.linalg.pinv(RootKf.T @ RootKf /
                           (2 * n * delta**2) + np.eye(n_feat_f) / 2)
        A = RootKh.T @ RootKf
        W = (A @ Q @ A.T + alpha * np.eye(n_feat_h))
        B = A @ Q @ RootKf.T @ y
        # Solve W a = B in the least-squares sense rather than forming pinv(W).
        self.a = np.linalg.lstsq(W, B, rcond=None)[0]
        self.fitted_delta = delta
        return self

    def predict(self, X):
        """Scale X with the fitted scaler, map it to features, and apply `self.a`.

        Requires `fit` to have been called (reads `transX`, `featX`, `a`).
        """
        X = self.transX.transform(X)
        return self.featX.transform(X) @ self.a


class ApproxRKHSIVCV(ApproxRKHSIV):
    """Variant of ApproxRKHSIV that picks gamma_hq and alpha_scale by K-fold CV.

    For every fold, every candidate gamma for the regressor kernel and every
    candidate alpha scale, a coefficient vector is fitted on the training
    part and scored on the held-out part. The pair with the lowest average
    score is used for a final fit on all the data.

    Attributes set by `fit`:
        gamma_gm : gamma actually used for the conditioning-variable kernel
        featCond : fitted feature map for the conditioning variables
        transX : fitted scaler for the regressors
        gamma_hq : selected gamma for the regressor kernel
        featX : fitted feature map for the regressors (selected gamma)
        best_alpha_scale : selected alpha scale
        best_alpha : alpha derived from the selected scale on the full sample
        a : coefficient vector over the regressor features
        fitted_delta : the delta value used in the final fit
    """

    def __init__(self, kernel_approx='nystrom', n_components=25,
                 gamma_gm='auto', gamma_hqs='auto', n_gamma_hqs=10,
                 delta_scale='auto', delta_exp='auto', alpha_scales='auto', n_alphas=30, cv=6):
        """
        Parameters:
            kernel_approx : what approximator to use; either 'nystrom' or 'rbfsampler' (for kitchen sinks)
            n_components : how many nystrom components to use
            gamma_gm : the gamma parameter for the kernel of f
            gamma_hqs : the list of gamma parameters for kernel of h
            n_gamma_hqs : size of the quantile grid used when gamma_hqs is
                automatic; this yields n_gamma_hqs - 1 candidate gammas
            delta_scale : the scale of the critical radius; delta_n = delta_scale / n**(delta_exp)
            delta_exp : the exponent of the critical radius; delta_n = delta_scale / n**(delta_exp)
            alpha_scales : a list of scale of the regularization to choose from; alpha = alpha_scale * (delta**4)
            n_alphas : how many alpha_scales to try
            cv : how many folds to use in cross-validation for alpha_scale
        """
        self.kernel_approx = kernel_approx
        self.n_components = n_components

        self.gamma_gm = gamma_gm
        self.gamma_hqs = gamma_hqs
        self.n_gamma_hqs = n_gamma_hqs

        self.delta_scale = delta_scale  # worst-case critical value of RKHS spaces
        self.delta_exp = delta_exp  # worst-case critical value of RKHS spaces
        self.alpha_scales = alpha_scales  # regularization strength from Theorem 5
        self.n_alphas = n_alphas
        self.cv = cv

    def _get_gamma_hqs(self, X):
        """Return the candidate gammas for the regressor kernel.

        When `_check_auto(self.gamma_hqs)` is true, the candidates are the
        reciprocals of the k / n_gamma_hqs quantiles (k = 1 .. n_gamma_hqs - 1)
        of the squared Euclidean distances between distinct rows of X.
        Otherwise `self.gamma_hqs` is returned unchanged.
        """
        if not _check_auto(self.gamma_hqs):
            return self.gamma_hqs

        # Import the function explicitly: `import sklearn` alone does not
        # guarantee that the `sklearn.metrics` submodule has been loaded.
        from sklearn.metrics import pairwise_distances

        sq_dists = pairwise_distances(X=X, metric='euclidean', squared=True)
        # Strict lower triangle: each unordered pair once, diagonal excluded.
        pair_dists = sq_dists[np.tril_indices(X.shape[0], -1)]
        levels = np.arange(1, self.n_gamma_hqs) / self.n_gamma_hqs
        return 1. / np.quantile(pair_dists, levels)

    def fit(self, AWX, model_target, AZX, type):
        """Select gamma_hq and alpha_scale by cross-validation, then refit.

        Parameters:
            AWX : feature matrix; regressors when type == 'estimate_h',
                conditioning variables when type == 'estimate_q'
            model_target : target values; a 2-D array is flattened
            AZX : feature matrix; conditioning variables when
                type == 'estimate_h', regressors when type == 'estimate_q'
            type : either 'estimate_h' or 'estimate_q'

        Returns:
            self

        Raises:
            ValueError : if `type` is not one of the two accepted values.
        """
        model_target = model_target.ravel() if model_target.ndim == 2 else model_target
        # `type` shadows the builtin, but it is part of the public signature.
        if type == 'estimate_h':
            X, condition = AWX, AZX
        elif type == 'estimate_q':
            X, condition = AZX, AWX
        else:
            raise ValueError(
                "type must be 'estimate_h' or 'estimate_q', got {!r}".format(type))
        y = model_target

        X, y = check_X_y(X, y, accept_sparse=True)
        condition, y = check_X_y(condition, y, accept_sparse=True)

        # Standardize condition and get gamma_gm -> RootKf
        condition = Scaler().fit_transform(condition)
        gamma_gm = self._get_gamma_gm(condition=condition)
        # NOTE(review): this overwrites the constructor setting (e.g. 'auto')
        # with the resolved value; confirm against `_get_gamma_gm` whether a
        # later re-fit on new data is expected to reuse it.
        self.gamma_gm = gamma_gm
        self.featCond = self._get_new_approx_instance(gamma=gamma_gm)
        RootKf = self.featCond.fit_transform(condition)

        # Standardize X and get gamma_hqs -> RootKhs
        self.transX = Scaler()
        self.transX.fit(X)
        X = self.transX.transform(X)
        gamma_hqs = self._get_gamma_hqs(X)
        RootKhs = [self._get_new_approx_instance(gamma=gammah).fit_transform(X)
                   for gammah in gamma_hqs]

        # delta & alpha (nominal fold sizes; these may be non-integer)
        n = X.shape[0]
        alpha_scales = self._get_alpha_scales()
        n_train = n * (self.cv - 1) / self.cv
        n_test = n / self.cv
        delta_train = self._get_delta(n_train)
        delta_test = self._get_delta(n_test)

        # Size the identity matrices from the feature matrices themselves:
        # Nystroem reduces the number of components to n_samples when
        # n_components > n_samples, so self.n_components may not match.
        half_eye_f = np.eye(RootKf.shape[1]) / 2

        # scores[fold][gamma index][alpha index]
        scores = []
        for train, test in KFold(n_splits=self.cv).split(X):
            RootKf_train, RootKf_test = RootKf[train], RootKf[test]
            Q_train = np.linalg.pinv(
                RootKf_train.T @ RootKf_train / (2 * n_train * (delta_train**2)) + half_eye_f)
            Q_test = np.linalg.pinv(
                RootKf_test.T @ RootKf_test / (2 * n_test * (delta_test**2)) + half_eye_f)
            fold_scores = []
            for RootKh in RootKhs:
                RootKh_train, RootKh_test = RootKh[train], RootKh[test]
                eye_h = np.eye(RootKh.shape[1])
                A_train = RootKh_train.T @ RootKf_train
                AQA_train = A_train @ Q_train @ A_train.T
                B_train = A_train @ Q_train @ RootKf_train.T @ y[train]
                gamma_scores = []
                for alpha_scale in alpha_scales:
                    alpha = self._get_alpha(delta_train, alpha_scale)
                    a = np.linalg.lstsq(AQA_train + alpha * eye_h,
                                        B_train, rcond=None)[0]
                    # Held-out residuals projected on the conditioning
                    # features, scored by a Q_test-weighted quadratic form.
                    res = RootKf_test.T @ (y[test] - RootKh_test @ a)
                    gamma_scores.append(
                        (res.T @ Q_test @ res).reshape(-1)[0] / (len(test)**2))
                fold_scores.append(gamma_scores)
            scores.append(fold_scores)

        # Average over folds, then pick the (gamma, alpha scale) minimizer.
        avg_scores = np.mean(np.array(scores), axis=0)
        best_ind = np.unravel_index(np.argmin(avg_scores), avg_scores.shape)

        self.gamma_hq = gamma_hqs[best_ind[0]]
        self.featX = self._get_new_approx_instance(gamma=self.gamma_hq)
        RootKh = self.featX.fit_transform(X)

        self.best_alpha_scale = alpha_scales[best_ind[1]]
        delta = self._get_delta(n)
        self.best_alpha = self._get_alpha(delta, self.best_alpha_scale)

        # Final fit on all data with the selected hyper-parameters.
        Q = np.linalg.pinv(RootKf.T @ RootKf /
                           (2 * n * delta**2) + half_eye_f)
        A = RootKh.T @ RootKf
        W = (A @ Q @ A.T + self.best_alpha * np.eye(RootKh.shape[1]))
        B = A @ Q @ RootKf.T @ y
        # Solve W a = B in the least-squares sense rather than forming pinv(W).
        self.a = np.linalg.lstsq(W, B, rcond=None)[0]
        self.fitted_delta = delta
        return self